import time,sys,thread

from NumericLattice import *


class Numerics:
    """Drive the time integration of a NumericLattice on a background thread.

    The constructor builds the lattice, loads the initial state and couplings,
    and starts ``numericLoop`` with ``thread.start_new_thread``.  Parameter
    setters that touch the lattice first park the integration thread with
    ``hold()`` and release it again with ``unhold()``.

    Pause handshake attributes:
        pause   -- requested state, written by callers (1 = please stop).
        paused  -- acknowledged state, written by the integration thread at
                   the top of every iteration (1 = not evolving the lattice).
        tvPause -- user "freeze" toggle; while it is 1, unhold() keeps the
                   integration stopped.
    """

    # Seconds slept while polling the pause handshake, so neither a waiting
    # caller nor the idle integration thread spins a CPU core.
    _POLL_INTERVAL = 0.001

    def __init__(self, nSpinsX, nSpinsY, k, zArray, phaseArray, phiArray, phiScale, phiRatio, Ja, Jbx, Jby, Jad, dt, visuals):
        self.nSpinsX = nSpinsX
        self.nSpinsY = nSpinsY

        self.k = k
        self.zArray = zArray
        self.phaseArray = phaseArray
        self.phiArray = phiArray

        self.phiScale = phiScale
        self.phiRatio = phiRatio
        self.magS2 = 1.0

        # The stripe spacing handed to NumericLattice is taken from the two
        # dimensions of zArray.
        self.stripeSpacingX = len(zArray)
        self.stripeSpacingY = len(zArray[0])

        self.Ja = Ja
        self.Jbx = Jbx
        self.Jby = Jby
        self.Jad = Jad

        self.t = 0.000
        self.dt = dt
        self.visuals = visuals

        self.tvUseHamiltonian = 0
        self.tvPause = 0
        self.n = -1
        self.baseSigma = 1.0

        self.EBu = 0.0
        self.EBd = 0.0
        self.temp = 0.0

        self.paused = 1
        self.pause = 1
        # Stop flag for numericLoop.  It must NOT be named ``terminate``: an
        # instance attribute of that name shadows the terminate() method and
        # makes ``numerics.terminate()`` raise TypeError.
        self.terminated = 0

        self.innerLoops = 64

        self.nLattice = NumericLattice(nSpinsX, nSpinsY, self.stripeSpacingX, self.stripeSpacingY)
        self.nLattice.setState(self.k, self.zArray, self.phaseArray, self.phiArray, self.phiScale, self.phiRatio, self.magS2)
        self.nLattice.setCouplings(self.Ja, self.Jbx, self.Jby, self.Jad)

        thread.start_new_thread(self.numericLoop, (self.innerLoops,))

    def numericLoop(self, innerLoops):
        """Integrate the lattice until terminate() is called.

        Runs on the thread started by __init__.  Each iteration publishes the
        requested pause state into ``self.paused``.  When not paused, it
        advances ``self.t`` by ``self.dt``, evolves the lattice by one step,
        and hands the lattice to ``visuals.visualLoop``.

        ``innerLoops`` is accepted for interface compatibility but unused.
        """
        print('-----------------------------')
        print("Time Integration Started")
        print(time.ctime())

        # Start running, unless the user froze the simulation before this
        # thread was scheduled.
        self.unhold()

        try:
            while not self._stopRequested():
                self.paused = self.pause
                if self.paused != 1:
                    self.t += self.dt
                    self.nLattice.timeEvolve(self.t, self.dt, self.temp, self.EBu, self.EBd)
                    self.visuals.visualLoop(self.nLattice, self.t)
                else:
                    # Held or frozen: idle instead of busy-spinning.
                    time.sleep(self._POLL_INTERVAL)
        finally:
            # The loop is gone, whether it stopped normally or an exception
            # escaped from the lattice or the visuals.  Report "paused" so a
            # later hold() (e.g. from reset()) cannot wait forever.
            self.paused = 1

    def _stopRequested(self):
        """Return True once numericLoop has been asked to stop.

        Besides the ``terminated`` flag set by terminate(), this honours
        legacy callers that stopped the loop by assigning
        ``numerics.terminate = 1`` directly.  Normally ``self.terminate`` is
        the bound method, which never compares equal to 1.
        """
        return self.terminated == 1 or self.terminate == 1

    def _loadSolitonState(self):
        """Push the current state parameters into the lattice (caller must hold())."""
        self.nLattice.setSolitonState(self.k, self.zArray, self.phaseArray, self.phiArray, self.phiScale, self.phiRatio, self.magS2)

    def updateCouplings(self):
        """Re-send Ja/Jbx/Jby/Jad to the lattice while integration is held."""
        self.hold()
        self.nLattice.setCouplings(self.Ja, self.Jbx, self.Jby, self.Jad)
        self.unhold()

    def updateState(self):
        """Re-initialise the lattice via setSolitonState while integration is held."""
        self.hold()
        self._loadSolitonState()
        self.unhold()

    # Read-only pass-throughs to the lattice.  They do not hold() the
    # integration thread.

    def returnState(self):
        return self.nLattice.returnState()

    def returnMeanFieldArray(self):
        return self.nLattice.returnMeanFieldArray()

    def returnSpinSum(self):
        return self.nLattice.returnSpinSum()

    def returnEnergies(self):
        return self.nLattice.returnEnergies()

    def returnBondXEnergies(self):
        return self.nLattice.returnBondXEnergies()

    def returnTotalEnergy(self):
        return self.nLattice.returnTotalEnergy()

    def returnCouplingX(self):
        return self.nLattice.returnCouplingX()

    def returnCouplingY(self):
        return self.nLattice.returnCouplingY()

    def returnStripeSpacingX(self):
        return self.nLattice.returnStripeSpacingX()

    def returnStripeSpacingY(self):
        return self.nLattice.returnStripeSpacingY()

    def setJbx(self, Jbx):
        """Set Jbx (any float-convertible value), then reload couplings and state."""
        self.Jbx = float(Jbx)
        self.updateCouplings()
        self.updateState()

    def setJby(self, Jby):
        """Set Jby (any float-convertible value), then reload couplings and state."""
        self.Jby = float(Jby)
        self.updateCouplings()
        self.updateState()

    def setPhiScale(self, PhiScale):
        """Set phiScale and reload the lattice state."""
        self.phiScale = float(PhiScale)
        self.updateState()

    def setPhiRatio(self, PhiRatio):
        """Set phiRatio and reload the lattice state."""
        self.phiRatio = float(PhiRatio)
        self.updateState()

    def setMagS2(self, magS2):
        """Set magS2 and reload the lattice state."""
        self.magS2 = float(magS2)
        self.updateState()

    def randomizeState(self):
        """Randomise the lattice state while integration is held."""
        # Hold like every other lattice mutator, so this cannot race with
        # timeEvolve running on the integration thread.
        self.hold()
        self.nLattice.randomizeState()
        self.unhold()

    def setEBu(self, EBu):
        """Set EBu; numericLoop passes it to every timeEvolve call."""
        self.EBu = float(EBu)

    def setEBd(self, EBd):
        """Set EBd; numericLoop passes it to every timeEvolve call."""
        self.EBd = float(EBd)

    def setTemp(self, temp):
        """Set the temperature, given in units of |Ja|."""
        self.temp = float(temp)*abs(self.Ja)

    def randomize(self):
        """Placeholder; currently does nothing."""
        pass

    def distort(self):
        """Placeholder; currently does nothing."""
        pass

    # The next three setters only store values; nothing in this class
    # reads tvUseHamiltonian, n or baseSigma.

    def toggleUseHamiltonian(self):
        self.tvUseHamiltonian = (self.tvUseHamiltonian + 1)%2

    def setHamState(self, n):
        self.n = int(n)

    def setBaseSigma(self, baseSigma):
        self.baseSigma = float(baseSigma)

    def setDt(self, dt):
        """Set the time step to dt/1000 (input presumably in thousandths -- confirm with caller)."""
        self.dt = float(dt)/1000.0

    def freeze(self):
        """Toggle the user freeze: hold integration when on, release it when off."""
        self.tvPause = (self.tvPause+1)%2
        if self.tvPause == 1:
            self.hold()
        else:
            self.unhold()

    def hold(self):
        """Request a pause and wait until the integration thread honours it.

        On return, numericLoop has published ``paused == 1`` and will not call
        timeEvolve until unhold().  NOTE: the wait depends on numericLoop
        acknowledging the request, so this must not be called from the
        integration thread itself (e.g. from inside visuals.visualLoop).
        NOTE(review): the flag handshake is not atomic.  A threading.Condition
        would close the small window between the loop reading ``pause`` and
        writing ``paused``.
        """
        self.pause = 1
        while self.paused == 0:
            time.sleep(self._POLL_INTERVAL)

    def unhold(self):
        """Release a hold(), unless the user has frozen the simulation."""
        # Without this guard, changing any parameter while frozen would
        # silently resume integration.
        if self.tvPause != 1:
            self.pause = 0

    def reset(self):
        """Reset time to 0, reload state and couplings, and reset the visuals."""
        self.hold()
        self.t = 0.000
        self._loadSolitonState()
        self.nLattice.setCouplings(self.Ja, self.Jbx, self.Jby, self.Jad)
        self.visuals.reset()
        self.unhold()

    def terminate(self):
        """Ask numericLoop to exit after its current iteration."""
        self.terminated = 1



    


    












##	self.tkr.scalePhiWidget = Scale(self.tkr.scalesFrame, orient=VERTICAL, length=160, from_=4.000, to=0.000, resolution=0.01, label="scale", command=lambda str: self.setScalePhi(str))
####        self.tkr.scalePhiWidget.set(scalePhi)
##        self.tkr.scalePhiWidget.grid(row=1, column=3)
        #


##    def setScalePhi(self, scalePhi):
####        self.scalePhi = float(scalePhi)
##        pass




##
##
##
##

##
##
##
##
##    def couplingDistortion(self):
##        scale = .1*self.Ja
##        for i in range(self.nSpinsX):
##            for j in range(self.nSpinsX):
##                randomV = random()
##                distortion = 2.0*scale*(randomV-.5)
##                if          i - j == 1 or j - i == 1:
##                    if      i%self.stripeSpacingX == 0 and j%self.stripeSpacingX == (self.stripeSpacingX-1):    self.couplingXArray[i][j] += distortion
##                    elif    j%self.stripeSpacingX == 0 and i%self.stripeSpacingX == (self.stripeSpacingX-1):    self.couplingXArray[i][j] += distortion
##                    else:                                                                                       self.couplingXArray[i][j] += distortion                  
##                if          i - j == self.nSpinsX - 1 or j - i == self.nSpinsX - 1:                             self.couplingXArray[i][j] += distortion
##
##        for i in range(self.nSpinsY):
##            for j in range(self.nSpinsY):
##                randomV = random()
##                distortion = 2.0*scale*(randomV-.5)
##                if          i - j == 1 or j - i == 1:
##                    if      i%self.stripeSpacingY == 0 and j%self.stripeSpacingY == (self.stripeSpacingY-1):    self.couplingYArray[i][j] += distortion
##                    elif    j%self.stripeSpacingY == 0 and i%self.stripeSpacingY == (self.stripeSpacingY-1):    self.couplingYArray[i][j] += distortion
##                    else:                                                                                       self.couplingYArray[i][j] += distortion
##                if          i - j == self.nSpinsY - 1 or j - i == self.nSpinsY - 1:                             self.couplingYArray[i][j] += distortion
##
##
##
##
##
##    def returnEnergies(self):
##        self.getEnergies()
##        return self.energyArray
##
##    def returnCommonAxis(self):
##        spinArray = self.nLattice.returnState()
##        commonAxis = vector(0.0,0.0,0.0)
##
##        TrackSpins = self.TrackSpins
##
##        for spin in TrackSpins:
##            commonAxis += spinArray[spin[0]][spin[1]][0]                  
##                    
##        return commonAxis
##
##    def getEnergies(self):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                self.energyArray[x][y] = -dot(self.spinArray[x][y][0],self.meanFieldArray[x][y][0])
##
##                if -self.energyArray[x][y] < -self.maxEnergyArray[x][y]:
##                    self.maxEnergyArray[x][y] = self.energyArray[x][y]
##                if self.energyArray[x][y] < self.minEnergyArray[x][y]:
##                    self.minEnergyArray[x][y] = self.energyArray[x][y]
##
##    def clearEnergies(self):
##        self.maxEnergyArray = -4.0*ones((self.nSpinsX,self.nSpinsY),Float32)
##        self.minEnergyArray = zeros((self.nSpinsX,self.nSpinsY),Float32)
##
##
##    def returnTotalEnergy(self):
##        energy = 0.0
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                energy = energy + self.energyArray[x][y]
##        return energy
##               
##    def returnSpinSum(self):
##        spinSum = vector(0.0,0.0,0.0)
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                spinSum += self.spinArray[x][y][0]
##        return spinSum
##
##

##    def moveSpin(self, spherePoint):
##        maxDot = -1.0
##        selectX = 0
##        selectY = 0
##
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                newDot = dot(spherePoint, self.spinArray[x][y][0])
##                if newDot > maxDot:
##                    maxDot = newDot
##                    selectX = x
##                    selectY = y
##
##        self.spinArray[selectX][selectY] = spherePoint


##      from Hamiltonian import *
        
##        if self.tvUseHamiltonian.get() == 1:
##            self.nLattice.setStateFromH(self.k, self.baseSigma, self.n)
##        else:
##            if self.baseSigma == 0.0:
##                self.nLattice.setStateManuallyPhi(self.k, self.scalePhi, self.zArray, self.phiArray, self.phaseArray)
##            else:
##                self.nLattice.setStateManually(self.k, self.baseSigma, self.zArray, self.sigmaArray, self.phaseArray)
##        self.nLattice.clearEnergies()
##        
##        self.nLattice.timeEvolve(self.t, self.dt, self.EBu, self.EBd, self.temp)

##        self.k = k
##        self.nTotSpins = nSpinsX*nSpinsY
##        self.nUpperLimit = self.nTotSpins
##        self.n = -1
        
##        self.Ham = Hamiltonian(stripeSpacingX,stripeSpacingY)
##        self.Ham.setzArray()
##        self.Ham.setCouplings(self.Ja, self.Jbx, self.Jby)
